【小ネタ】[Amazon SageMaker] 既存のモデルのデプロイをJupyter Notebookでやってみました
1 はじめに
CX事業本部の平内(SIN)です
SageMaker Python SDKには、Modelクラスがあり、これをエンドポイントにデプロイできます。 今回は、Jupyter Notebook上で、過去に作成したモデル(ビルトインの物体検出)をデプロイする要領を確認してみました。
参考:https://sagemaker.readthedocs.io/en/stable/model.html
Jupyter Notebookのサンプルでは、データセットから学習して出来上がったモデルを使用して作業するパターンが殆どで、既存モデルから生成できるModelクラスの使い方に、ちょっと戸惑ったので、その覚書です。
2 Jupyter Notebook
Jupyter Notebookの内容は、以下の通りです。
(1) Setup
最初に、SageMakerのセッションや、ロールを準備します。
import sagemaker from sagemaker import get_execution_role from sagemaker.amazon.amazon_estimator import get_image_uri role = get_execution_role() sess = sagemaker.Session()
(2) Create Model
ここで、Modelクラスのインスタンスを生成しています。
Dockerイメージ(training_image)は、当該モデルを作成したビルトインのobject-detectionです。 また、使用する既存モデル(model.tar.gz)は、S3に配置する必要があります。
ちょっと注意が必要なのは、predictor_clsを指定していないと、deploy()でNoneが返され、推論するための識別子を利用できないことです。
from sagemaker.model import Model from sagemaker.predictor import RealTimePredictor, json_deserializer class ImagePredictor(RealTimePredictor): def __init__(self, endpoint_name, sagemaker_session): super().__init__(endpoint_name, sagemaker_session=sagemaker_session, serializer=None, deserializer=json_deserializer, content_type='image/jpeg') training_image = get_image_uri(sess.boto_region_name, 'object-detection', repo_version="latest") model_data = 's3://sagemaker-working-bucket/Sweets/output/model.tar.gz' model = Model(role =role,image=training_image,model_data = model_data, predictor_cls=ImagePredictor, sagemaker_session=sess)
(3) Deploy
インスタンスの種類と数を指定してエンドポイントを生成します。
object_detector = model.deploy(initial_instance_count = 1, instance_type = 'ml.m4.xlarge')
(4) Detection
下記は、サンプルコードであるawslabs/amazon-sagemaker-examplesと同じで、テスト画像に結果を描画してます。
import json def visualize_detection(img_file, dets, classes=[], thresh=0.1): import random import matplotlib.pyplot as plt import matplotlib.image as mpimg img=mpimg.imread(img_file) plt.imshow(img) height = img.shape[0] width = img.shape[1] colors = dict() for det in dets: (klass, score, x0, y0, x1, y1) = det if score < thresh: continue cls_id = int(klass) if cls_id not in colors: colors[cls_id] = (random.random(), random.random(), random.random()) xmin = int(x0 * width) ymin = int(y0 * height) xmax = int(x1 * width) ymax = int(y1 * height) rect = plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, edgecolor=colors[cls_id], linewidth=3.5) plt.gca().add_patch(rect) class_name = str(cls_id) if classes and len(classes) > cls_id: class_name = classes[cls_id] plt.gca().text(xmin, ymin - 2, '{:s} {:.3f}'.format(class_name, score), bbox=dict(facecolor=colors[cls_id], alpha=0.5), fontsize=12, color='white') plt.show()
file_name = 'TestData_Sweets/sweet001.png' with open(file_name, 'rb') as image: f = image.read() b = bytearray(f) ne = open('n.txt','wb') ne.write(b) object_detector.content_type = 'image/jpeg' detections = object_detector.predict(b) print(detections)
object_categories = ['BlackThunder','HomePie','Bisco'] threshold = 0.2 visualize_detection(file_name, detections['prediction'], object_categories, threshold)
(5) Delete Endpoint
次のコードでエンドポイントを削除しています。
sagemaker.Session().delete_endpoint(object_detector.endpoint)
3 最後に
今回は、既存のモデルのデプロイをJupyter Notebookでやってみました。何をするにも必要となるModelクラスの利用方法は、しっかり掴みたいと思います。
コードは、下記に起きました。
https://gist.github.com/furuya02/9ecbc1773aff4536f113e2ab8fa6097e